import os
import typing
import typing_extensions

import baml_py

from . import types, stream_types, type_builder
from .globals import DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME as __runtime__, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX as __ctx__manager__


class BamlCallOptions(typing.TypedDict, total=False):
    """Per-call options accepted by the generated BAML client.

    Every key is optional. The call manager resolves these into a
    ``_ResolvedBamlOptions`` before each runtime call.
    """

    # TypeBuilder wrapper; its private ``_tb`` attribute is what is handed to the runtime.
    tb: typing_extensions.NotRequired[type_builder.TypeBuilder]
    # Forwarded unchanged to the runtime as the client registry.
    client_registry: typing_extensions.NotRequired[baml_py.baml_py.ClientRegistry]
    # Overrides layered on top of a copy of os.environ; a None value removes
    # that variable for the call.
    env: typing_extensions.NotRequired[typing.Dict[str, typing.Optional[str]]]
    # Forwarded to the runtime as call tags.
    tags: typing_extensions.NotRequired[typing.Dict[str, str]]
    # A single collector or a list of collectors; normalized to a list.
    collector: typing_extensions.NotRequired[
        typing.Union[baml_py.baml_py.Collector, typing.List[baml_py.baml_py.Collector]]
    ]
    # Checked before function calls (an already-aborted controller raises
    # BamlAbortError) and forwarded to the runtime.
    abort_controller: typing_extensions.NotRequired[baml_py.baml_py.AbortController]
    # Called as on_tick("Unknown", log) with the latest log of an internal
    # collector. Only forwarded to the runtime for async streams; sync
    # streams raise ValueError when it is set.
    on_tick: typing_extensions.NotRequired[typing.Callable[[str, baml_py.baml_py.FunctionLog], None]]
    watchers: typing_extensions.NotRequired[typing.Any]  # EventCollector type, will be overridden in generated clients


class _ResolvedBamlOptions:
    """Fully-resolved, runtime-ready form of ``BamlCallOptions``.

    Holds exactly the values the call manager forwards to the BAML runtime:
    an unwrapped type builder, a list of collectors, the effective
    environment, and an ``on_tick`` callback that takes no arguments.
    """

    # Unwrapped runtime TypeBuilder, or None when none was supplied.
    tb: typing.Optional[baml_py.baml_py.TypeBuilder]
    client_registry: typing.Optional[baml_py.baml_py.ClientRegistry]
    # Collectors as a list (possibly empty).
    collectors: typing.List[baml_py.baml_py.Collector]
    # Effective environment variables for the call.
    env_vars: typing.Dict[str, str]
    tags: typing.Dict[str, str]
    abort_controller: typing.Optional[baml_py.baml_py.AbortController]
    # Zero-argument callback (already wrapped around the user's on_tick).
    on_tick: typing.Optional[typing.Callable[[], None]]
    watchers: typing.Optional[typing.Any]

    def __init__(
        self,
        tb: typing.Optional[baml_py.baml_py.TypeBuilder],
        client_registry: typing.Optional[baml_py.baml_py.ClientRegistry],
        collectors: typing.List[baml_py.baml_py.Collector],
        env_vars: typing.Dict[str, str],
        tags: typing.Dict[str, str],
        abort_controller: typing.Optional[baml_py.baml_py.AbortController],
        on_tick: typing.Optional[typing.Callable[[], None]],
        watchers: typing.Optional[typing.Any],
    ):
        # Store every argument verbatim; the assignments happen in field order.
        (
            self.tb,
            self.client_registry,
            self.collectors,
            self.env_vars,
            self.tags,
            self.abort_controller,
            self.on_tick,
            self.watchers,
        ) = (
            tb,
            client_registry,
            collectors,
            env_vars,
            tags,
            abort_controller,
            on_tick,
            watchers,
        )

{# Template-local names: the rendered call-manager class name and its private
   options-resolve method. "__resolve" is name-mangled by Python, so
   self.__resolve() only resolves from inside the class body. #}
{% set call_manager_name = "DoNotUseDirectlyCallManager" %}
{% set managled_resolve_fn = "__resolve" %}

class {{ call_manager_name }}:
    def __init__(self, baml_options: BamlCallOptions):
        self.__baml_options = baml_options

    def __getstate__(self):
        # Return state needed for pickling
        return {"baml_options": self.__baml_options}

    def __setstate__(self, state):
        # Restore state from pickling
        self.__baml_options = state["baml_options"]

    def __resolve(self) -> _ResolvedBamlOptions:
        tb = self.__baml_options.get("tb")
        if tb is not None:
            baml_tb = tb._tb  # type: ignore (we know how to use this private attribute)
        else:
            baml_tb = None
        client_registry = self.__baml_options.get("client_registry")
        collector = self.__baml_options.get("collector")
        collectors_as_list = (
            collector
            if isinstance(collector, list)
            else [collector] if collector is not None else []
        )
        env_vars = os.environ.copy()
        for k, v in self.__baml_options.get("env", {}).items():
            if v is not None:
                env_vars[k] = v
            else:
                env_vars.pop(k, None)

        tags = self.__baml_options.get("tags", {}) or {}

        abort_controller = self.__baml_options.get("abort_controller")

        on_tick = self.__baml_options.get("on_tick")
        if on_tick is not None:
            collector = baml_py.baml_py.Collector("on-tick-collector")
            collectors_as_list.append(collector)
            def on_tick_wrapper():
                log = collector.last
                if log is not None:
                    on_tick("Unknown", log)
        else:
            on_tick_wrapper = None

        watchers = self.__baml_options.get("watchers")

        return _ResolvedBamlOptions(
            baml_tb,
            client_registry,
            collectors_as_list,
            env_vars,
            tags,
            abort_controller,
            on_tick_wrapper,
            watchers,
        )

    def merge_options(self, options: BamlCallOptions) -> "{{call_manager_name}}":
        return DoNotUseDirectlyCallManager({**self.__baml_options, **options})

    async def call_function_async(
        self, *, function_name: str, args: typing.Dict[str, typing.Any]
    ) -> baml_py.baml_py.FunctionResult:
        resolved_options = self.__resolve()

        # Check if already aborted
        if resolved_options.abort_controller is not None and resolved_options.abort_controller.aborted:
            raise baml_py.baml_py.BamlAbortError("Operation was aborted")

        return await __runtime__.call_function(
            function_name,
            args,
            # ctx
            __ctx__manager__.clone_context(),
            # tb
            resolved_options.tb,
            # cr
            resolved_options.client_registry,
            # collectors
            resolved_options.collectors,
            # env_vars
            resolved_options.env_vars,
            # tags
            resolved_options.tags,
            # abort_controller
            resolved_options.abort_controller,
            # watchers
            resolved_options.watchers,
        )

    def call_function_sync(
        self, *, function_name: str, args: typing.Dict[str, typing.Any]
    ) -> baml_py.baml_py.FunctionResult:
        resolved_options = self.{{ managled_resolve_fn }}()

        # Check if already aborted
        if resolved_options.abort_controller is not None and resolved_options.abort_controller.aborted:
            raise baml_py.baml_py.BamlAbortError("Operation was aborted")

        ctx = __ctx__manager__.get()
        return __runtime__.call_function_sync(
            function_name,
            args,
            # ctx
            ctx,
            # tb
            resolved_options.tb,
            # cr
            resolved_options.client_registry,
            # collectors
            resolved_options.collectors,
            # env_vars
            resolved_options.env_vars,
            # tags
            resolved_options.tags,
            # abort_controller
            resolved_options.abort_controller,
            # watchers
            resolved_options.watchers,
        )

    def create_async_stream(
        self,
        *,
        function_name: str,
        args: typing.Dict[str, typing.Any],
    ) -> typing.Tuple[baml_py.baml_py.RuntimeContextManager, baml_py.baml_py.FunctionResultStream]:
        resolved_options = self.{{ managled_resolve_fn }}()
        ctx = __ctx__manager__.clone_context()
        result = __runtime__.stream_function(
            function_name,
            args,
            # this is always None, we set this later!
            # on_event
            None,
            # ctx
            ctx,
            # tb
            resolved_options.tb,
            # cr
            resolved_options.client_registry,
            # collectors
            resolved_options.collectors,
            # env_vars
            resolved_options.env_vars,
            # tags
            resolved_options.tags,
            # on_tick
            resolved_options.on_tick,
            # abort_controller
            resolved_options.abort_controller,
        )
        return ctx, result

    def create_sync_stream(
        self,
        *,
        function_name: str,
        args: typing.Dict[str, typing.Any],
    ) -> typing.Tuple[baml_py.baml_py.RuntimeContextManager, baml_py.baml_py.SyncFunctionResultStream]:
        resolved_options = self.{{ managled_resolve_fn }}()
        if resolved_options.on_tick is not None:
            raise ValueError("on_tick is not supported for sync streams. Please use async streams instead.")
        ctx = __ctx__manager__.get()
        result = __runtime__.stream_function_sync(
            function_name,
            args,
            # this is always None, we set this later!
            # on_event
            None,
            # ctx
            ctx,
            # tb
            resolved_options.tb,
            # cr
            resolved_options.client_registry,
            # collectors
            resolved_options.collectors,
            # env_vars
            resolved_options.env_vars,
            # tags
            resolved_options.tags,
            # on_tick
            # always None! sync streams don't support on_tick
            None,
            # abort_controller
            resolved_options.abort_controller,
        )
        return ctx, result

    async def create_http_request_async(
        self,
        *,
        function_name: str,
        args: typing.Dict[str, typing.Any],
        mode: typing_extensions.Literal["stream", "request"],
    ) -> baml_py.baml_py.HTTPRequest:
        resolved_options = self.{{ managled_resolve_fn }}()
        return await __runtime__.build_request(
            function_name,
            args,
            # ctx
            __ctx__manager__.clone_context(),
            # tb
            resolved_options.tb,
            # cr
            resolved_options.client_registry,
            # env_vars
            resolved_options.env_vars,
            # is_stream
            mode == "stream",
        )

    def create_http_request_sync(
        self,
        *,
        function_name: str,
        args: typing.Dict[str, typing.Any],
        mode: typing_extensions.Literal["stream", "request"],
    ) -> baml_py.baml_py.HTTPRequest:
        resolved_options = self.{{ managled_resolve_fn }}()
        return __runtime__.build_request_sync(
            function_name,
            args,
            # ctx
            __ctx__manager__.get(),
            # tb
            resolved_options.tb,
            # cr
            resolved_options.client_registry,
            # env_vars
            resolved_options.env_vars,
            # is_stream
            mode == "stream",
        )

    def parse_response(self, *, function_name: str, llm_response: str, mode: typing_extensions.Literal["stream", "request"]) -> typing.Any:
        resolved_options = self.{{ managled_resolve_fn }}()
        return __runtime__.parse_llm_response(
            function_name,
            llm_response,
            # enum_module
            types,
            # cls_module
            types,
            # partial_cls_module
            stream_types,
            # allow_partials
            mode == "stream",
            # ctx
            __ctx__manager__.get(),
            # tb
            resolved_options.tb,
            # cr
            resolved_options.client_registry,
            # env_vars
            resolved_options.env_vars,
        )


def disassemble(function: typing.Callable) -> None:
    """Print the BAML runtime's disassembly of a generated client function.

    ``function`` should be a method of the generated client ``b``
    (e.g. ``b.MyFunction``); it is matched by ``__name__``. Anything else is
    reported with a message and ignored. Returns nothing.
    """
    import inspect
    from . import b

    if not callable(function):
        print(f"disassemble: object {function} is not a Baml function")
        return

    # Callables such as functools.partial or callable instances may have no
    # __name__; report them instead of raising AttributeError.
    name = getattr(function, "__name__", None)
    if not isinstance(name, str):
        print(f"disassemble: object {function} is not a Baml function")
        return

    # Look up just this one attribute rather than inspect.getmembers(b),
    # which evaluates every attribute of the client to check a single name.
    if not inspect.ismethod(getattr(b, name, None)):
        print(f"disassemble: function {name} is not a Baml function")
        return

    print(f"----- function {name} -----")
    __runtime__.disassemble(name)
